from scalevi.nn import nn
from scalevi.nn import initializers
import jax
import jax.numpy as np
import scalevi.distributions as dists 
from typing import Any, Callable, Sequence, Optional, Tuple, Union

# Loose type aliases; most are placeholders until a concrete type is chosen.
PRNGKey = Any
# Variable-length tuple of ints (``Tuple[int]`` would mean exactly one int).
Shape = Tuple[int, ...]
Dtype = Any  # this could be a real type?
Array = Any

class Encoder(nn.Module):
    """Encode ``x`` together with ``θ`` into ``loc`` / ``scale_tril`` outputs.

    The input goes through ``nn.FCN``, then ``nn.Stats``, then
    ``nn.Bilinear`` (paired with ``θ`` prefixed by a constant ``1``). The
    bilinear output is cut at index ``split``: the head is returned as
    ``"loc"`` and the tail is passed through
    ``dists.util.vec_to_tril_matrix`` and ``scale_transform.forward`` to
    produce ``"scale_tril"``.
    """

    # Layer widths handed to nn.FCN; `encoding_features` is appended last.
    fcn_features: Sequence[int]
    # Final width of the FCN stack.
    encoding_features: int
    # First positional argument of nn.Bilinear.
    bl_features: int
    # Forwarded unchanged to nn.Stats.
    keep_sum_stats: bool
    # Index separating the "loc" entries from the scale entries.
    split: int
    # Object exposing `.forward(matrix)`; applied to the triangular matrix.
    scale_transform: Any

    @nn.compact
    def __call__(self, x, θ):
        """Run the encoder.

        Args:
            x: input passed to the FCN stack.
            θ: values combined with the encoded ``x`` in the bilinear layer.

        Returns:
            dict with keys ``"loc"`` and ``"scale_tril"``.
        """
        name = "Encoder" if self.name is None else self.name

        # Unpack instead of `fcn_features + [...]` so that tuples (or any
        # other sequence) are accepted, not only lists.
        widths = [*self.fcn_features, self.encoding_features]
        x = nn.FCN(widths, name=f"{name}_FCN")(x)
        x = nn.Stats(
            keep_sum_stats=self.keep_sum_stats,
            name=f"{name}_Stats")(x)
        # The leading 1 presumably lets the bilinear map carry terms that do
        # not depend on θ -- TODO confirm against nn.Bilinear.
        x = nn.Bilinear(
            self.bl_features,
            name=f"{name}_Bilinear")(x, np.append(1, θ))

        scale_vec = x[self.split:]
        return {
            "loc": x[:self.split],
            "scale_tril": self.scale_transform.forward(
                dists.util.vec_to_tril_matrix(scale_vec)),
        }

class MaskedEncoder(Encoder):
    """``Encoder`` variant whose pooling step is ``nn.MaskedStats``.

    Reuses the fields declared on ``Encoder``; only ``__call__`` differs, by
    taking an extra ``mask`` that is forwarded to ``nn.MaskedStats``.
    """

    @nn.compact
    def __call__(self, x, θ, mask):
        """Run the masked encoder.

        Args:
            x: input passed to the FCN stack.
            θ: values combined with the encoded ``x`` in the bilinear layer.
            mask: forwarded as the second argument of ``nn.MaskedStats``.

        Returns:
            dict with keys ``"loc"`` and ``"scale_tril"``.
        """
        name = "MaskedEncoder" if self.name is None else self.name

        # Unpack instead of `fcn_features + [...]` so that tuples (or any
        # other sequence) are accepted, not only lists.
        widths = [*self.fcn_features, self.encoding_features]
        x = nn.FCN(widths, name=f"{name}_FCN")(x)
        x = nn.MaskedStats(
            keep_sum_stats=self.keep_sum_stats,
            name=f"{name}_MaskedStats")(x, mask)
        x = nn.Bilinear(
            self.bl_features,
            name=f"{name}_Bilinear")(x, np.append(1, θ))

        scale_vec = x[self.split:]
        return {
            "loc": x[:self.split],
            "scale_tril": self.scale_transform.forward(
                dists.util.vec_to_tril_matrix(scale_vec)),
        }

class MaskedEncoder_vGaussian(nn.Module):
    """Masked encoder whose ``loc`` is an affine function of ``θ``.

    After FCN -> MaskedStats -> FCN, a bias-free ``nn.Dense`` layer emits a
    flat vector that is read as three consecutive pieces:

    * ``D_z * D_θ`` entries reshaped to a ``(D_z, D_θ)`` matrix ``A``,
    * ``D_z`` entries used as an offset ``b``,
    * ``D_z * (D_z + 1) // 2`` entries fed to
      ``dists.util.vec_to_tril_matrix``.

    The result is ``loc = A @ θ + b`` and
    ``scale_tril = scale_transform.forward(tril)``.
    """

    # Layer widths of the first FCN; `encoding_features` is appended last.
    fcn_1_features: Sequence[int]
    # Layer widths of the second FCN (used as given).
    fcn_2_features: Sequence[int]
    # Final width of the first FCN stack.
    encoding_features: int
    # Forwarded unchanged to nn.MaskedStats.
    keep_sum_stats: bool
    # Length of the returned "loc".
    D_z: int
    # Number of columns of the matrix multiplied with θ.
    D_θ: int
    # Object exposing `.forward(matrix)`; applied to the triangular matrix.
    scale_transform: Any
    # When True, θ is appended to the pooled features before the second FCN.
    encode_θ: bool = False

    @nn.compact
    def __call__(self, x, θ, mask):
        """Run the encoder.

        Args:
            x: input passed to the first FCN stack.
            θ: vector multiplied by the predicted ``(D_z, D_θ)`` matrix; also
                appended to the second FCN's input when ``encode_θ`` is set.
            mask: forwarded as the second argument of ``nn.MaskedStats``.

        Returns:
            dict with keys ``"loc"`` and ``"scale_tril"``.
        """
        name = "MaskedEncoder_vGauss" if self.name is None else self.name
        D_z, D_θ = self.D_z, self.D_θ

        # Unpack instead of `fcn_1_features + [...]` so that tuples (or any
        # other sequence) are accepted, not only lists.
        widths_1 = [*self.fcn_1_features, self.encoding_features]
        x = nn.FCN(widths_1, name=f"{name}_FCN_1")(x)
        x = nn.MaskedStats(
            keep_sum_stats=self.keep_sum_stats,
            name=f"{name}_MaskedStats")(x, mask)

        # Both branches build the same FCN_2; only its input differs.
        fcn_2_input = np.append(x, θ) if self.encode_θ else x
        x = nn.FCN(self.fcn_2_features, name=f"{name}_FCN_2")(fcn_2_input)

        # Offsets of the three pieces inside the flat Dense output.
        n_matrix = D_z * D_θ
        n_affine = n_matrix + D_z
        n_tril = D_z * (D_z + 1) // 2
        x = nn.Dense(
            features=n_affine + n_tril,
            use_bias=False,
            name=f"{name}_Linear",
            kernel_init=initializers.normal(0.001))(x)

        weight = x[:n_matrix].reshape(D_z, D_θ)
        offset = x[n_matrix:n_affine]
        return {
            "loc": np.dot(weight, θ) + offset,
            "scale_tril": self.scale_transform.forward(
                dists.util.vec_to_tril_matrix(x[n_affine:])),
        }

class MaskedEncoder_vGaussian_vSynthetic(nn.Module):
    """Masked encoder whose ``loc`` is an affine function of ``θ``.

    After FCN -> MaskedStats -> FCN, a bias-free ``nn.Dense`` layer emits a
    flat vector that is read as three consecutive pieces:

    * ``D_z * D_θ`` entries reshaped to a ``(D_z, D_θ)`` matrix ``A``,
    * ``D_z`` entries used as an offset ``b``,
    * ``D_z * (D_z + 1) // 2`` entries fed to
      ``dists.util.vec_to_tril_matrix``.

    The result is ``loc = A @ θ + b`` and
    ``scale_tril = scale_transform.forward(tril)``.

    NOTE: when the module has no explicit name, the submodule name prefix is
    ``"MaskedEncoder_vGauss"`` (it does not mention "vSynthetic"). It is kept
    as is because changing it would rename the parameters.
    """

    # Layer widths of the first FCN; `encoding_features` is appended last.
    fcn_1_features: Sequence[int]
    # Layer widths of the second FCN (used as given).
    fcn_2_features: Sequence[int]
    # Final width of the first FCN stack.
    encoding_features: int
    # Forwarded unchanged to nn.MaskedStats.
    keep_sum_stats: bool
    # Length of the returned "loc".
    D_z: int
    # Number of columns of the matrix multiplied with θ.
    D_θ: int
    # Object exposing `.forward(matrix)`; applied to the triangular matrix.
    scale_transform: Any
    # When True, θ is appended to the pooled features before the second FCN.
    encode_θ: bool = False

    @nn.compact
    def __call__(self, x, θ, mask):
        """Run the encoder.

        Args:
            x: input passed to the first FCN stack.
            θ: vector multiplied by the predicted ``(D_z, D_θ)`` matrix; also
                appended to the second FCN's input when ``encode_θ`` is set.
            mask: forwarded as the second argument of ``nn.MaskedStats``.

        Returns:
            dict with keys ``"loc"`` and ``"scale_tril"``.
        """
        name = "MaskedEncoder_vGauss" if self.name is None else self.name
        D_z, D_θ = self.D_z, self.D_θ

        # Unpack instead of `fcn_1_features + [...]` so that tuples (or any
        # other sequence) are accepted, not only lists.
        widths_1 = [*self.fcn_1_features, self.encoding_features]
        x = nn.FCN(widths_1, name=f"{name}_FCN_1")(x)
        x = nn.MaskedStats(
            keep_sum_stats=self.keep_sum_stats,
            name=f"{name}_MaskedStats")(x, mask)

        # Both branches build the same FCN_2; only its input differs.
        fcn_2_input = np.append(x, θ) if self.encode_θ else x
        x = nn.FCN(self.fcn_2_features, name=f"{name}_FCN_2")(fcn_2_input)

        # Offsets of the three pieces inside the flat Dense output.
        n_matrix = D_z * D_θ
        n_affine = n_matrix + D_z
        n_tril = D_z * (D_z + 1) // 2
        x = nn.Dense(
            features=n_affine + n_tril,
            use_bias=False,
            name=f"{name}_Linear",
            kernel_init=initializers.normal(0.001))(x)

        weight = x[:n_matrix].reshape(D_z, D_θ)
        offset = x[n_matrix:n_affine]
        return {
            "loc": np.dot(weight, θ) + offset,
            "scale_tril": self.scale_transform.forward(
                dists.util.vec_to_tril_matrix(x[n_affine:])),
        }

class MaskedEncoder_vBlockGaussian(nn.Module):
    """Masked encoder emitting ``loc`` and ``scale_tril`` directly.

    After FCN -> MaskedStats -> FCN, a bias-free ``nn.Dense`` layer emits
    ``D_z + D_z * (D_z + 1) // 2`` values. The first ``D_z`` are returned as
    ``"loc"``; the remainder goes through ``dists.util.vec_to_tril_matrix``
    and ``scale_transform.forward`` to give ``"scale_tril"``. ``θ`` only
    enters the computation when ``encode_θ`` is True.
    """

    # Layer widths of the first FCN; `encoding_features` is appended last.
    fcn_1_features: Sequence[int]
    # Layer widths of the second FCN (used as given).
    fcn_2_features: Sequence[int]
    # Final width of the first FCN stack.
    encoding_features: int
    # Forwarded unchanged to nn.MaskedStats.
    keep_sum_stats: bool
    # Length of the returned "loc".
    D_z: int
    # Not read by __call__; kept so the constructor signature is unchanged.
    D_θ: int
    # Object exposing `.forward(matrix)`; applied to the triangular matrix.
    scale_transform: Any
    # When True, θ is appended to the pooled features before the second FCN.
    encode_θ: bool = False

    @nn.compact
    def __call__(self, x, θ, mask):
        """Run the encoder.

        Args:
            x: input passed to the first FCN stack.
            θ: appended to the second FCN's input when ``encode_θ`` is set;
                otherwise unused.
            mask: forwarded as the second argument of ``nn.MaskedStats``.

        Returns:
            dict with keys ``"loc"`` and ``"scale_tril"``.
        """
        name = "MaskedEncoder_vBlock" if self.name is None else self.name
        D_z = self.D_z

        # Unpack instead of `fcn_1_features + [...]` so that tuples (or any
        # other sequence) are accepted, not only lists.
        widths_1 = [*self.fcn_1_features, self.encoding_features]
        x = nn.FCN(widths_1, name=f"{name}_FCN_1")(x)
        x = nn.MaskedStats(
            keep_sum_stats=self.keep_sum_stats,
            name=f"{name}_MaskedStats")(x, mask)

        # Both branches build the same FCN_2; only its input differs.
        fcn_2_input = np.append(x, θ) if self.encode_θ else x
        x = nn.FCN(self.fcn_2_features, name=f"{name}_FCN_2")(fcn_2_input)

        n_tril = D_z * (D_z + 1) // 2
        x = nn.Dense(
            features=D_z + n_tril,
            use_bias=False,
            name=f"{name}_Linear",
            kernel_init=initializers.normal(0.001))(x)

        return {
            "loc": x[:D_z],
            "scale_tril": self.scale_transform.forward(
                dists.util.vec_to_tril_matrix(x[D_z:])),
        }

class MaskedEncoder_vDiagonal(nn.Module):
    """Masked encoder emitting ``mu`` and ``sig`` vectors.

    After FCN -> MaskedStats -> FCN, a bias-free ``nn.Dense`` layer emits
    ``2 * D_z`` values. The first ``D_z`` are returned as ``"mu"``; the
    remaining ``D_z`` go through ``scale_transform.forward_diag_transform``
    and are returned as ``"sig"``. ``θ`` only enters the computation when
    ``encode_θ`` is True.
    """

    # Layer widths of the first FCN; `encoding_features` is appended last.
    fcn_1_features: Sequence[int]
    # Layer widths of the second FCN (used as given).
    fcn_2_features: Sequence[int]
    # Final width of the first FCN stack.
    encoding_features: int
    # Forwarded unchanged to nn.MaskedStats.
    keep_sum_stats: bool
    # Length of each of the returned "mu" and "sig".
    D_z: int
    # Not read by __call__; kept so the constructor signature is unchanged.
    D_θ: int
    # Object exposing `.forward_diag_transform(vector)`.
    scale_transform: Any
    # When True, θ is appended to the pooled features before the second FCN.
    encode_θ: bool = False

    @nn.compact
    def __call__(self, x, θ, mask):
        """Run the encoder.

        Args:
            x: input passed to the first FCN stack.
            θ: appended to the second FCN's input when ``encode_θ`` is set;
                otherwise unused.
            mask: forwarded as the second argument of ``nn.MaskedStats``.

        Returns:
            dict with keys ``"mu"`` and ``"sig"``.
        """
        name = "MaskedEncoder_vDiag" if self.name is None else self.name
        D_z = self.D_z

        # Unpack instead of `fcn_1_features + [...]` so that tuples (or any
        # other sequence) are accepted, not only lists.
        widths_1 = [*self.fcn_1_features, self.encoding_features]
        x = nn.FCN(widths_1, name=f"{name}_FCN_1")(x)
        x = nn.MaskedStats(
            keep_sum_stats=self.keep_sum_stats,
            name=f"{name}_MaskedStats")(x, mask)

        # Both branches build the same FCN_2; only its input differs.
        fcn_2_input = np.append(x, θ) if self.encode_θ else x
        x = nn.FCN(self.fcn_2_features, name=f"{name}_FCN_2")(fcn_2_input)

        x = nn.Dense(
            features=D_z + D_z,
            use_bias=False,
            name=f"{name}_Linear",
            kernel_init=initializers.normal(0.001))(x)

        return {
            "mu": x[:D_z],
            "sig": self.scale_transform.forward_diag_transform(x[D_z:]),
        }

class Bilinear(nn.Module):
    """Single ``nn.Bilinear`` layer producing ``loc`` / ``scale_tril``.

    ``x`` is combined with ``θ`` (prefixed by a constant ``1``) in a bilinear
    layer. The output is cut at index ``split``: the head is ``"loc"``, the
    tail is passed through ``dists.util.vec_to_tril_matrix`` and
    ``scale_transform.forward`` to give ``"scale_tril"``.
    """

    # First positional argument of nn.Bilinear.
    bl_features: int
    # Index separating the "loc" entries from the scale entries.
    split: int
    # Object exposing `.forward(matrix)`; applied to the triangular matrix.
    scale_transform: Any

    @nn.compact
    def __call__(self, x, θ):
        """Return a dict with keys ``"loc"`` and ``"scale_tril"``."""
        name = self.name if self.name is not None else "Bilinear"

        θ_with_one = np.append(1, θ)
        layer = nn.Bilinear(self.bl_features, name=f"{name}_Bilinear")
        out = layer(x, θ_with_one)

        loc = out[:self.split]
        tril = dists.util.vec_to_tril_matrix(out[self.split:])
        return {"loc": loc, "scale_tril": self.scale_transform.forward(tril)}

class BilinearMean(nn.Module):
    """Bilinear ``loc`` head plus a bias-free dense ``scale_tril`` head.

    ``"loc"`` is the output of ``nn.Bilinear`` applied to ``x`` and ``θ``
    (prefixed by a constant ``1``). ``"scale_tril"`` depends on ``x`` only:
    a bias-free ``nn.Dense`` output is passed through
    ``dists.util.vec_to_tril_matrix`` and ``scale_transform.forward``.
    """

    # First positional argument of nn.Bilinear.
    bl_features: int
    # Output width of the Dense layer feeding vec_to_tril_matrix.
    features: int
    # Object exposing `.forward(matrix)`; applied to the triangular matrix.
    scale_transform: Any

    @nn.compact
    def __call__(self, x, θ):
        """Return a dict with keys ``"loc"`` and ``"scale_tril"``."""
        name = self.name if self.name is not None else "BilinearMean"

        # Submodules are created in the same order as before: Bilinear, Dense.
        bilinear = nn.Bilinear(self.bl_features, name=f"{name}_Bilinear")
        loc = bilinear(x, np.append(1, θ))

        dense = nn.Dense(self.features, use_bias=False, name=f"{name}_Dense")
        tril = dists.util.vec_to_tril_matrix(dense(x))
        scale_tril = self.scale_transform.forward(tril)

        return {"loc": loc, "scale_tril": scale_tril}

